{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import LSTM,Dropout\n",
    "from keras.layers import Dense\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import decomposition\n",
    "from sklearn.feature_selection import SelectKBest\n",
    "from sklearn.feature_selection import chi2\n",
    "from sklearn.pipeline import FeatureUnion\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "import matplotlib.pyplot as plt\n",
    "import scikitplot as skplt\n",
    "from sklearn.preprocessing import label_binarize\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from scipy import interp\n",
    "from itertools import cycle\n",
    "import time\n",
    "from keras import metrics\n",
    "from keras import optimizers\n",
    "from keras import initializers\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "import pprint\n",
    "from keras.callbacks import Callback\n",
    "from keras.layers import Bidirectional\n",
    "from keras.layers import TimeDistributed\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global random generator before anything else runs.\n",
    "np.random.seed(seed=42)\n",
    "# Load the pre-processed KDD table from the working directory.\n",
    "kdd_dataset = pd.read_csv(filepath_or_buffer='KDD.preProcessed.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>duration</th>\n",
       "      <th>service</th>\n",
       "      <th>src_bytes</th>\n",
       "      <th>dst_bytes</th>\n",
       "      <th>land</th>\n",
       "      <th>wrong_fragment</th>\n",
       "      <th>urgent</th>\n",
       "      <th>hot</th>\n",
       "      <th>num_failed_logins</th>\n",
       "      <th>logged_in</th>\n",
       "      <th>...</th>\n",
       "      <th>flag_S1</th>\n",
       "      <th>flag_S2</th>\n",
       "      <th>flag_S3</th>\n",
       "      <th>flag_SF</th>\n",
       "      <th>flag_SH</th>\n",
       "      <th>label_DoS</th>\n",
       "      <th>label_Normal</th>\n",
       "      <th>label_Probe</th>\n",
       "      <th>label_R2L</th>\n",
       "      <th>label_U2R</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.347826</td>\n",
       "      <td>0.030859</td>\n",
       "      <td>0.846454</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.217391</td>\n",
       "      <td>0.072004</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.217391</td>\n",
       "      <td>0.143175</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.420290</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.710145</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 58 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   duration   service  src_bytes  dst_bytes  land  wrong_fragment  urgent  \\\n",
       "0       0.0  0.347826   0.030859   0.846454   0.0             0.0     0.0   \n",
       "1       0.0  0.217391   0.072004   0.000000   0.0             0.0     0.0   \n",
       "2       0.0  0.217391   0.143175   0.000000   0.0             0.0     0.0   \n",
       "3       0.0  0.420290   0.000000   0.000000   0.0             0.0     0.0   \n",
       "4       0.0  0.710145   0.000000   0.000000   0.0             0.0     0.0   \n",
       "\n",
       "   hot  num_failed_logins  logged_in    ...      flag_S1  flag_S2  flag_S3  \\\n",
       "0  0.0                0.0        1.0    ...          0.0      0.0      0.0   \n",
       "1  0.0                0.0        0.0    ...          0.0      0.0      0.0   \n",
       "2  0.0                0.0        0.0    ...          0.0      0.0      0.0   \n",
       "3  0.0                0.0        0.0    ...          0.0      0.0      0.0   \n",
       "4  0.0                0.0        0.0    ...          0.0      0.0      0.0   \n",
       "\n",
       "   flag_SF  flag_SH  label_DoS  label_Normal  label_Probe  label_R2L  \\\n",
       "0      1.0      0.0        0.0           1.0          0.0        0.0   \n",
       "1      1.0      0.0        1.0           0.0          0.0        0.0   \n",
       "2      1.0      0.0        1.0           0.0          0.0        0.0   \n",
       "3      0.0      0.0        1.0           0.0          0.0        0.0   \n",
       "4      0.0      0.0        0.0           0.0          1.0        0.0   \n",
       "\n",
       "   label_U2R  \n",
       "0        0.0  \n",
       "1        0.0  \n",
       "2        0.0  \n",
       "3        0.0  \n",
       "4        0.0  \n",
       "\n",
       "[5 rows x 58 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the loaded table.\n",
    "kdd_dataset.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Start from every column name, then drop the five one-hot class-label\n",
    "# columns so that only the input features remain.\n",
    "features = list(kdd_dataset)\n",
    "for label_name in ('label_DoS', 'label_Normal', 'label_Probe', 'label_R2L', 'label_U2R'):\n",
    "    features.remove(label_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# One DataFrame per traffic class: the rows whose one-hot label column equals 1.\n",
    "DoS = kdd_dataset[kdd_dataset['label_DoS'].eq(1)]\n",
    "Normal = kdd_dataset[kdd_dataset['label_Normal'].eq(1)]\n",
    "Probe = kdd_dataset[kdd_dataset['label_Probe'].eq(1)]\n",
    "R2L = kdd_dataset[kdd_dataset['label_R2L'].eq(1)]\n",
    "U2R = kdd_dataset[kdd_dataset['label_U2R'].eq(1)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Reduce the over-represented classes (author's note: by a factor of 40, keeping\n",
    "# the class distribution). Each frame is sampled to a fixed row count with a\n",
    "# fixed seed; U2R is not touched here.\n",
    "DoS, Normal, Probe, R2L = (\n",
    "    frame.sample(n=size, random_state=42)\n",
    "    for frame, size in ((DoS, 97084), (Normal, 24320), (Probe, 1027), (R2L, 1126))\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#Optional oversampling of undersampled data.\n",
    "\n",
    "#U2R = U2R.append([U2R]*500,ignore_index=True)\n",
    "#R2L = R2L.append([R2L]*100,ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Stack the per-class frames back together, then shuffle every row with a fixed seed.\n",
    "reduced_dataset = pd.concat([DoS, Normal, Probe, R2L, U2R], axis=0)\n",
    "kdd_dataset = reduced_dataset.sample(frac=1, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of datapoints = 123579 and number of features = 53\n"
     ]
    }
   ],
   "source": [
    "# Feature matrix: the columns listed in `features`, as a NumPy array.\n",
    "x = kdd_dataset[features].values\n",
    "# Label matrix: select the label columns by name (every column that is not a\n",
    "# feature) rather than by a hard-coded positional offset, so the split stays\n",
    "# correct if the number or order of columns changes.\n",
    "label_columns = [column for column in kdd_dataset.columns if column not in features]\n",
    "y = kdd_dataset[label_columns].values\n",
    "print(\"number of datapoints = {} and number of features = {}\".format(x.shape[0], x.shape[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Counter({(1.0, 0.0, 0.0, 0.0, 0.0): 97084, (0.0, 1.0, 0.0, 0.0, 0.0): 24320, (0.0, 0.0, 0.0, 1.0, 0.0): 1126, (0.0, 0.0, 1.0, 0.0, 0.0): 1027, (0.0, 0.0, 0.0, 0.0, 1.0): 22})\n"
     ]
    }
   ],
   "source": [
    "# Tally how many rows carry each distinct one-hot label vector.\n",
    "from collections import Counter\n",
    "print(Counter(map(tuple, y)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape before transformation - (123579, 53)\n",
      "Shape after transformation - (123579, 10)\n"
     ]
    }
   ],
   "source": [
    "#PCA\n",
    "# Project the feature matrix onto its first 10 principal components.\n",
    "# NOTE(review): the PCA is fitted on the full data set before the\n",
    "# train/validation/test split below, so held-out rows influence the projection;\n",
    "# consider fitting on the training split only.\n",
    "\n",
    "print(\"Shape before transformation - {}\".format(np.asarray(x).shape))\n",
    "# random_state pins the solver's randomness so the projection is repeatable.\n",
    "pca = decomposition.PCA(n_components=10, random_state=42)\n",
    "pca.fit(x)\n",
    "# Keep the result as a NumPy array: train_test_split and the later reshaping\n",
    "# accept arrays directly, so a nested Python list is unnecessary overhead.\n",
    "x_pca = pca.transform(x)\n",
    "print(\"Shape after transformation - {}\".format(x_pca.shape))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#Train:Test:Val split - 60:20:20\n",
    "# First cut: hold out 20% of all rows, stratified by label, as the validation set.\n",
    "X_tr, X_val, Y_tr, Y_val = train_test_split(\n",
    "    x_pca, y,\n",
    "    test_size=0.2,\n",
    "    random_state=42,\n",
    "    stratify=y)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Second cut: 25% of the remaining rows (stratified by label) become the test set.\n",
    "X_train, X_test, Y_train, Y_test = train_test_split(\n",
    "    X_tr, Y_tr,\n",
    "    test_size=0.25,\n",
    "    random_state=42,\n",
    "    stratify=Y_tr)\n",
    "# Number of training rows; ResetStatesCallback reads this as its reset period.\n",
    "max_len = len(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def processTrainTestArrays(x,y):\n",
    "    x=np.asarray(x)\n",
    "    y=np.asarray(y)\n",
    "    x = np.reshape(x, (x.shape[0], 1, x.shape[1]))\n",
    "    y=np.reshape(y, (y.shape[0], 1, y.shape[1]))\n",
    "    #x = np.reshape(x, (x.shape[0],x.shape[1],1))\n",
    "    #y = np.reshape(y, (y.shape[0], y.shape[1]))\n",
    "    return x,y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def onehotencode(y):\n",
    "    \"\"\"Return the one-hot (binary class matrix) encoding of integer class labels `y`.\n",
    "\n",
    "    Thin wrapper around Keras' `to_categorical`.\n",
    "    \"\"\"\n",
    "    # Import from the public `keras.utils` namespace; the private\n",
    "    # `keras.utils.np_utils` module path is not available in later Keras releases.\n",
    "    from keras.utils import to_categorical\n",
    "    y_binary = to_categorical(y)\n",
    "    return y_binary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# define model\n",
    "\n",
    "def create_model(X_train,Y_train,batch_size=1,number_of_units=20):\n",
    "    \"\"\"Build and compile a stateful bidirectional-LSTM classifier with 5 softmax outputs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_train : 3-D array\n",
    "        Only its shape is read: shape[1] and shape[2] become the last two entries\n",
    "        of the model's `batch_input_shape`.\n",
    "    Y_train : any\n",
    "        Unused; kept so existing two-argument calls keep working.\n",
    "    batch_size : int, default 1\n",
    "        Fixed batch size baked into `batch_input_shape`.\n",
    "    number_of_units : int, default 20\n",
    "        Units of the LSTM wrapped by `Bidirectional`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The compiled `Sequential` model.\n",
    "    \"\"\"\n",
    "    # Seeded initializer shared by kernel and bias weights.\n",
    "    randomInit = initializers.RandomUniform(seed=42)\n",
    "\n",
    "    model = Sequential()\n",
    "    model.add(Bidirectional(LSTM(number_of_units, stateful=True, return_sequences=True,\n",
    "                                 kernel_initializer=randomInit, bias_initializer=randomInit),\n",
    "                            batch_input_shape=(batch_size, X_train.shape[1], X_train.shape[2])))\n",
    "    model.add(Dropout(0.15))\n",
    "    # One softmax over the 5 classes per time step.\n",
    "    model.add(TimeDistributed(Dense(5, activation='softmax')))\n",
    "\n",
    "    nadam = optimizers.Nadam(lr=0.002, beta_1=0.9, beta_2=0.999, epsilon=1e-08, schedule_decay=0.004)\n",
    "    model.compile(loss='categorical_crossentropy', optimizer=nadam, metrics=['accuracy'])\n",
    "\n",
    "    # summary() prints the table itself and returns None, so it is not wrapped in print().\n",
    "    model.summary()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def compute_metrics(predicted, Y_test):\n",
    "    \"\"\"Report classification quality for the five-class one-hot problem.\n",
    "\n",
    "    Prints the weighted F1 score and a confusion matrix, then plots per-class,\n",
    "    micro-averaged and macro-averaged ROC curves.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    predicted : array-like, shape (n_samples, 5)\n",
    "        Per-class scores, one row per sample.\n",
    "    Y_test : array-like, shape (n_samples, 5)\n",
    "        One-hot true labels, same column order as `predicted`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    from sklearn.metrics import f1_score\n",
    "\n",
    "    n_classes = 5\n",
    "    class_ids = list(range(n_classes))\n",
    "    Y_test = np.asarray(Y_test)\n",
    "    predicted = np.asarray(predicted)\n",
    "\n",
    "    # Collapse each one-hot / score row to the index of its largest entry.\n",
    "    Y_classes = Y_test.argmax(axis=1)\n",
    "    predicted_classes = predicted.argmax(axis=1)\n",
    "\n",
    "    #F1 score\n",
    "    print(\"Average F1 score is {}\".format(\n",
    "        f1_score(Y_classes, predicted_classes, average='weighted', labels=class_ids)))\n",
    "\n",
    "    # Categoricals declare all five class ids, including any that never occur.\n",
    "    y_act = pd.Categorical(Y_classes, categories=class_ids)\n",
    "    y_pre = pd.Categorical(predicted_classes, categories=class_ids)\n",
    "    df_confusion = pd.crosstab(y_act, y_pre, rownames=['Actual'], colnames=['Predicted'])\n",
    "    print(\"Confusion matrix: \\n\", df_confusion)\n",
    "\n",
    "    # Per-class ROC curve and area (one-vs-rest on each label column).\n",
    "    fpr = dict()\n",
    "    tpr = dict()\n",
    "    roc_auc = dict()\n",
    "    for i in range(n_classes):\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_test[:, i], predicted[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "\n",
    "    # Compute micro-average ROC curve and ROC area\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_test.ravel(), predicted.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    lw = 2\n",
    "    # Compute macro-average ROC curve and ROC area\n",
    "    # First aggregate all false positive rates\n",
    "    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))\n",
    "\n",
    "    # Then interpolate all ROC curves at these points. np.interp replaces the\n",
    "    # deprecated scipy.interp alias, which is missing from recent SciPy releases.\n",
    "    mean_tpr = np.zeros_like(all_fpr)\n",
    "    for i in range(n_classes):\n",
    "        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "\n",
    "    # Finally average it and compute AUC\n",
    "    mean_tpr /= n_classes\n",
    "\n",
    "    fpr[\"macro\"] = all_fpr\n",
    "    tpr[\"macro\"] = mean_tpr\n",
    "    roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    # Plot all ROC curves\n",
    "    plt.figure()\n",
    "    plt.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "             label='micro-average ROC curve (area = {0:0.2f})'\n",
    "                   ''.format(roc_auc[\"micro\"]),\n",
    "             color='deeppink', linestyle=':', linewidth=4)\n",
    "\n",
    "    plt.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "             label='macro-average ROC curve (area = {0:0.2f})'\n",
    "                   ''.format(roc_auc[\"macro\"]),\n",
    "             color='navy', linestyle=':', linewidth=4)\n",
    "\n",
    "    # One distinct colour per class so no two class curves share a colour.\n",
    "    colors = ['aqua', 'darkorange', 'cornflowerblue', 'green', 'red']\n",
    "    for i, color in zip(range(n_classes), colors):\n",
    "        plt.plot(fpr[i], tpr[i], color=color, lw=lw,\n",
    "                 label='ROC curve of class {0} (area = {1:0.2f})'\n",
    "                 ''.format(i, roc_auc[i]))\n",
    "\n",
    "    # Diagonal reference line.\n",
    "    plt.plot([0, 1], [0, 1], 'k--', lw=lw)\n",
    "    plt.xlim([0.0, 1.0])\n",
    "    plt.ylim([0.0, 1.05])\n",
    "    plt.xlabel('False Positive Rate')\n",
    "    plt.ylabel('True Positive Rate')\n",
    "    plt.legend(loc=\"lower right\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#checkpoint save weights\n",
    "# Write weights only, and only when the monitored 'val_acc' value reaches a new maximum.\n",
    "hdf5FileName = \"bestWeightsBiLSTMStateful1L20.hdf5\"\n",
    "checkpoint = ModelCheckpoint(\n",
    "    filepath=hdf5FileName,\n",
    "    monitor='val_acc',\n",
    "    mode='max',\n",
    "    save_best_only=True,\n",
    "    save_weights_only=True,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ResetStatesCallback(Callback):\n",
    "    \"\"\"Keras callback that resets the model's recurrent state every N training batches.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    reset_every : int or None, default None\n",
    "        Number of batches between resets. When None, the module-level `max_len`\n",
    "        is looked up at each batch, as before.\n",
    "    \"\"\"\n",
    "    def __init__(self, reset_every=None):\n",
    "        # Let the base Callback set up its own attributes first.\n",
    "        super(ResetStatesCallback, self).__init__()\n",
    "        self.reset_every = reset_every\n",
    "        # Batches seen since the last reset.\n",
    "        self.counter = 0\n",
    "\n",
    "    def on_batch_begin(self, batch, logs=None):\n",
    "        \"\"\"Reset the model state when the batch counter hits a multiple of the period.\"\"\"\n",
    "        period = max_len if self.reset_every is None else self.reset_every\n",
    "        if self.counter % period == 0:\n",
    "            self.model.reset_states()\n",
    "            print(\"Model reset. \",self.counter)\n",
    "            self.counter = 0\n",
    "        self.counter += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-24-dfb24aa89017>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocessTrainTestArrays\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0mx_val\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocessTrainTestArrays\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY_val\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m model.fit(x_train, y_train, epochs=30, batch_size=Batch_size, verbose=1, shuffle=False,\n",
      "\u001b[0;32m<ipython-input-21-492c59392131>\u001b[0m in \u001b[0;36mcreate_model\u001b[0;34m(X_train, Y_train)\u001b[0m\n\u001b[1;32m     11\u001b[0m                             batch_input_shape=(Batch_size,X_train.shape[1], X_train.shape[2])))\n\u001b[1;32m     12\u001b[0m     \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.15\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m     \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTimeDistributed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mactivation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'softmax'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/ass6ash/anaconda/lib/python3.6/site-packages/keras/models.py\u001b[0m in \u001b[0;36madd\u001b[0;34m(self, layer)\u001b[0m\n\u001b[1;32m    473\u001b[0m                           output_shapes=[self.outputs[0]._keras_shape])\n\u001b[1;32m    474\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 475\u001b[0;31m             \u001b[0moutput_tensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    476\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_tensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    477\u001b[0m                 raise TypeError('All layers in a Sequential model '\n",
      "\u001b[0;32m/Users/ass6ash/anaconda/lib/python3.6/site-packages/keras/engine/topology.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs, **kwargs)\u001b[0m\n\u001b[1;32m    574\u001b[0m                                          '`layer.build(batch_input_shape)`')\n\u001b[1;32m    575\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_shapes\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 576\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_shapes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    577\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    578\u001b[0m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_shapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/ass6ash/anaconda/lib/python3.6/site-packages/keras/layers/wrappers.py\u001b[0m in \u001b[0;36mbuild\u001b[0;34m(self, input_shape)\u001b[0m\n\u001b[1;32m    148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    149\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mbuild\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 150\u001b[0;31m         \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    151\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput_spec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mInputSpec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    152\u001b[0m         \u001b[0mchild_input_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "#Implementing validation\n",
    "\n",
    "Batch_size = 1\n",
    "start_time = time.time()\n",
    "\n",
    "# Give the train and validation splits the extra length-1 axis, then build the model.\n",
    "x_train, y_train = processTrainTestArrays(X_train, Y_train)\n",
    "x_val, y_val = processTrainTestArrays(X_val, Y_val)\n",
    "model = create_model(x_train, y_train)\n",
    "\n",
    "# shuffle=False keeps the rows in their given order across batches.\n",
    "model.fit(\n",
    "    x=x_train,\n",
    "    y=y_train,\n",
    "    batch_size=Batch_size,\n",
    "    epochs=30,\n",
    "    verbose=1,\n",
    "    callbacks=[checkpoint, ResetStatesCallback()],\n",
    "    validation_data=(x_val, y_val),\n",
    "    shuffle=False,\n",
    ")\n",
    "\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#Test with best model\n",
    "#Load weights and compile again\n",
    "print(\"\\n===========================\\nTime for testing\\n===========================\\n\")\n",
    "model.load_weights(hdf5FileName)\n",
    "nadam = optimizers.Nadam(lr=0.002, beta_1=0.9, beta_2=0.999, epsilon=1e-08, schedule_decay=0.004)\n",
    "model.compile(loss='categorical_crossentropy', optimizer=nadam, metrics=['accuracy'])\n",
    "print(\"Optimal weights loaded from file {}\".format(hdf5FileName))\n",
    "print(\"Model Successfully compiled with loaded weights\\n\")\n",
    "\n",
    "#Do same preprocessing for test data\n",
    "x_test,y_test = processTrainTestArrays(X_test,Y_test)\n",
    "# The LSTM is stateful: clear whatever state is left from training/validation\n",
    "# so the test pass does not start from it.\n",
    "model.reset_states()\n",
    "loss,acc = model.evaluate(x_test,y_test,batch_size=Batch_size)\n",
    "print(\"Loss for testing = {} and Accuracy for testing = {}\".format(loss,acc))\n",
    "# Reset again so predict() starts from the same empty state evaluate() started from.\n",
    "model.reset_states()\n",
    "predicted = model.predict(x_test,batch_size=Batch_size)\n",
    "# Drop the length-1 middle axis: (n, 1, classes) -> (n, classes).\n",
    "predicted = np.reshape(predicted,(predicted.shape[0],predicted.shape[2]))\n",
    "y_test = np.reshape(y_test,(y_test.shape[0],y_test.shape[2]))\n",
    "compute_metrics(predicted, y_test)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
